import os
import math
import json
import pickle
import torch.nn as nn
from functools import reduce
import pandas as pd
from common.flatten_object import flatten_object


class ObjectDict(dict):
    """A ``dict`` whose items can also be read, set and deleted as attributes.

    Reading or deleting a name that is not a key raises ``AttributeError``
    (not ``KeyError``), with the class name and the missing name in the
    message.
    """

    def __getattr__(self, name):
        # Python calls __getattr__ only after normal attribute lookup has
        # failed, so regular dict methods are still found first.
        if name in self:
            return self[name]
        raise AttributeError('{}: no such attribute: {}'.format(self.__class__.__name__, name))

    def __setattr__(self, name, value):
        # Every attribute assignment is stored as a dictionary item.
        self[name] = value

    def __delattr__(self, name):
        if name in self:
            del self[name]
            return
        raise AttributeError('{}: no such attribute: {}'.format(self.__class__.__name__, name))


class ModelAnalyzer:

    class Info:
        """Per-module record holding the statistics gathered for one module.

        Attributes:
            key: the key passed to the constructor.
            module_type: class name of the module passed to the constructor.
            fwd_in, fwd_out, grad_in, grad_out, macs: initially empty
                ``ObjectDict`` containers, one per statistics group.
        """

        # Embedded in the cache file name built by info_filename().
        __version__ = 1.0

        def __init__(self, key, module):
            self.key = key
            self.module_type = type(module).__name__
            # Attribute creation order is kept stable on purpose: it is the
            # order in which vars(info) enumerates the groups.
            for group in ('fwd_in', 'fwd_out', 'grad_in', 'grad_out', 'macs'):
                setattr(self, group, ObjectDict())

    def __init__(self, model, model_name, batch_size, train_step, out_path=None):
        """Collect per-module statistics for ``model`` and estimate MACs.

        Construction does all the work: it prepares the output directory,
        selects the modules to track, gathers statistics (from the cache file
        or by calling ``train_step``) and runs the analysis.

        Args:
            model: object whose ``named_modules()`` is used to pick the
                tracked modules.
            model_name: name embedded in the cache and result file names.
            batch_size: value embedded in the cache and result file names.
            train_step: zero-argument callable invoked once while the hooks
                are registered; it is not called when cached statistics are
                restored. Presumably it runs one forward/backward pass --
                confirm against the caller.
            out_path: output directory; ``None`` selects
                ``<cwd>/results/analysis``.
        """
        self.model = model
        self.model_name = model_name
        self.batch_size = batch_size
        self.train_step = train_step
        # create_out_path() also creates the directory when it is missing.
        self.path = self.create_out_path(out_path)
        self.modules = self.get_modules()
        self.stats = {}
        # Order matters: collect_info() needs self.path and self.modules and
        # fills self.stats, which analyze() then consumes.
        self.collect_info()
        self.analyze()

    def info_filename(self):
        ver = self.Info.__version__
        return os.path.join(self.path, 'info_{}_v{}_b{}.pkl'.format(self.model_name, ver, self.batch_size))

    def save_info(self):
        filename = self.info_filename()
        with open(filename, 'wb') as f:
            pickle.dump(self.stats, f)

    def restore_info(self):
        filename = self.info_filename()
        if os.path.exists(filename):
            with open(filename, 'rb') as f:
                self.stats = pickle.load(f)
                assert self.stats is not None
            return True
        return False

    def collect_info(self):
        def set_module_key_attribute():
            for key, m in self.modules.items():
                setattr(m, '_module_key', key)

        def del_module_key_attribute():
            for key, m in self.modules.items():
                delattr(m, '_module_key')

        if not self.restore_info():
            set_module_key_attribute()
            hooks = self.register_hooks()
            self.train_step()
            self.save_info()
            del_module_key_attribute()
            map(lambda h: h.remove(), hooks)

    @staticmethod
    def create_out_path(path):
        if path is None:
            path = os.path.join(os.getcwd(), 'results', 'analysis')
        if not os.path.exists(path):
            os.makedirs(path)
        return path

    def get_modules(self):
        def is_required(m):
            return any([isinstance(m, module) for module in include_modules])

        include_modules = [nn.Conv2d, nn.Linear]
        modules = {k: m for k, m in self.model.named_modules() if is_required(m)}
        return modules

    def register_hooks(self):
        hooks = []
        for m in self.modules.values():
            hooks.append(m.register_forward_hook(self.forward_hook_cb))
            hooks.append(m.register_backward_hook(self.backward_hook_cb))
        return hooks

    def get_stat_entry(self, module):
        key = getattr(module, '_module_key', None)
        assert key is not None
        entry = self.stats.get(key, None)
        if entry is None:
            entry = ModelAnalyzer.Info(key, module)
            self.stats[key] = entry
        return entry

    def forward_hook_cb(self, module, inp, out):
        entry = self.get_stat_entry(module)
        entry.fwd_in.shape = tuple(inp[0].shape)
        entry.fwd_in.numel = inp[0].numel()
        entry.fwd_out.shape = tuple(out.shape)
        entry.fwd_out.numel = out.numel()

    # Names given, by position, to the entries of the gradient tuple handled
    # in backward_hook_cb: the first table is used for nn.Conv2d instances,
    # the second for every other tracked module.
    # NOTE(review): the orderings presumably mirror what the legacy
    # register_backward_hook delivers for each layer type -- confirm against
    # the PyTorch version in use (e.g. for layers created with bias=False).
    conv2d_grads = ['features', 'weights', 'bias']
    linear_grads = ['bias', 'features', 'weights']

    def backward_hook_cb(self, module, grad_out, grad_in):
        """Backward hook: record shape and element count of the gradients.

        ``grad_in[0]`` is stored under ``entry.grad_in``. Every element of
        ``grad_out`` is stored under ``entry.grad_out``, keyed by the name
        found at the same position in ``conv2d_grads`` (for ``nn.Conv2d``
        instances) or ``linear_grads`` (otherwise); a ``None`` gradient is
        recorded with shape ``None`` and numel 0. At most three gradients are
        accepted.

        NOTE(review): PyTorch documents backward hooks as
        ``hook(module, grad_input, grad_output)``, so the ``grad_out``
        parameter here receives PyTorch's ``grad_input`` and ``grad_in`` its
        ``grad_output`` -- confirm the naming is intentional.
        """
        record = self.get_stat_entry(module)
        incoming = grad_in[0]
        record.grad_in.shape = tuple(incoming.shape)
        record.grad_in.numel = incoming.numel()

        if isinstance(module, nn.Conv2d):
            labels = ModelAnalyzer.conv2d_grads
        else:
            labels = ModelAnalyzer.linear_grads

        # Start from an empty mapping so names recorded by an earlier call
        # do not linger.
        record.grad_out = ObjectDict({})
        for position, grad in enumerate(grad_out):
            assert position <= 2
            if grad is None:
                shape, numel = None, 0
            else:
                shape, numel = tuple(grad.shape), grad.numel()
            record.grad_out[labels[position]] = ObjectDict(shape=shape, numel=numel)

    @staticmethod
    def calc_conv2d_macs(out_shape, kw, kh, in_channels, groups):
        out_size = reduce((lambda x, y: x * y), out_shape)
        return math.ceil(out_size * kw * kh * in_channels / groups)

    def estimate_macs_module_conv2d(self, key):
        def fwd_macs():
            kernel_size = math.ceil(kh * kw * ch_in / groups)
            return info.fwd_out.numel * kernel_size

        def grad_wrt_features_macs():
            if info.grad_out.features.shape is None:
                return 0
            yb, _, yh, yw = info.fwd_out.shape
            kernel_size = math.ceil(kh * kw * ch_out / groups)
            return yb * ch_in * yh * yw * kernel_size

        def grad_wrt_weight_macs():
            yb, _, yh, yw = info.fwd_out.shape
            kernel_size = yb * yh * yw
            return kh * kw * ch_in * ch_out * kernel_size

        def grad_wrt_bias_macs():
            yb, _, yh, yw = info.fwd_out.shape
            return yb * yh * yw

        module = self.modules[key]
        info = self.stats[key]
        kh, kw = module.kernel_size
        ch_in, ch_out = module.in_channels, module.out_channels
        groups = module.groups

        fwd_macs = fwd_macs()
        grad_f_macs = grad_wrt_features_macs()
        grad_w_macs = grad_wrt_weight_macs()
        grad_b_macs = grad_wrt_bias_macs()
        return fwd_macs, grad_f_macs, grad_w_macs, grad_b_macs

    def estimate_macs_module_linear(self, key):
        module = self.modules[key]
        info = self.stats[key]
        fwd_macs = info.fwd_in.shape[0] * module.in_features * module.out_features
        grad_b_macs = module.out_features
        return fwd_macs, fwd_macs, fwd_macs, grad_b_macs

    def estimate_macs(self):
        estimators = {
            'Conv2d': self.estimate_macs_module_conv2d,
            'ConvBn2d': self.estimate_macs_module_conv2d,
            'ConvBnReLU2d': self.estimate_macs_module_conv2d,
            'Conv2d_BF16': self.estimate_macs_module_conv2d,
            'Linear': self.estimate_macs_module_linear,
            'LinearReLU': self.estimate_macs_module_linear,
            'Linear_BF16': self.estimate_macs_module_linear,
        }

        for key, info in self.stats.items():
            if not info.macs:
                estimator = estimators.get(info.module_type, None)
                assert estimator is not None, 'No MAC estimator for {}'.format(info.module_type)
                fwd, grad_f, grad_w, grad_b = estimator(key)
                info.macs.fwd = fwd
                info.macs.grad_features = grad_f
                info.macs.grad_weights = grad_w
                info.macs.grad_bias = grad_b

    def analyze(self):
        """Run the analysis over the collected stats (currently only MAC estimation)."""
        self.estimate_macs()

    def to_csv(self, filename=None):
        """Write the stats as CSV, one row per tracked module.

        Each Info record's attribute dict is passed through
        ``flatten_object`` before the table is built; the row index (the
        ``self.stats`` key) is not written.

        Args:
            filename: target file; ``None`` selects
                ``results_<model_name>_b<batch_size>.csv`` in ``self.path``.
        """
        if filename is None:
            default_name = 'results_{}_b{}.csv'.format(self.model_name, self.batch_size)
            filename = os.path.join(self.path, default_name)
        rows = {}
        for key, entry in self.stats.items():
            rows[key] = flatten_object(vars(entry))
        frame = pd.DataFrame.from_dict(rows, orient='index')
        frame.to_csv(filename, index=False)

    def to_json(self, filename=None):
        if filename is None:
            filename = os.path.join(self.path, 'results_{}_b{}.json'.format(self.model_name, self.batch_size))
        stats = {k: vars(e) for k, e in self.stats.items()}
        with open(filename, 'w') as fp:
            json.dump(stats, fp, indent=4, sort_keys=True)

    def to_file(self, formatter='json', filename=None):
        assert formatter in ('json', 'csv')
        if formatter == 'json':
            self.to_json(filename)
        else:
            self.to_csv(filename)

    def to_dict(self):
        stats = {k: vars(e) for k, e in self.stats.items()}
        return stats
